/**
 * Copyright 2018 jianggujin (www.jianggujin.com).
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.jianggujin.codec;

import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;

import javax.crypto.Cipher;
import javax.crypto.KeyAgreement;
import javax.crypto.SecretKey;
import javax.crypto.interfaces.DHPublicKey;
import javax.crypto.spec.DHParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import com.jianggujin.codec.util.JCipherInputStream;
import com.jianggujin.codec.util.JCipherOutputStream;
import com.jianggujin.codec.util.JCodecException;
import com.jianggujin.codec.util.JCodecUtils;

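/**
 * Diffie-Hellman (D-H) key-agreement protocol, an idea proposed by Diffie and Hellman, the founders of public-key
 * cryptography. In short, it lets two users exchange information over a public channel and arrive at the same
 * ("agreed"), shareable secret key. Party A generates a key pair (public key, private key), and party B generates its own
 * key pair (public key, private key) from party A's public key. This is the basis for protecting transmitted data: both
 * parties use the same symmetric algorithm to build a local key (SecretKey) and encrypt the data with it. Once the
 * local-key algorithm has been agreed and both parties have published their public keys, each party encrypts data using
 * the other party's public key and its own private key. Likewise, each party decrypts data using the other party's public
 * key and its own private key. The scheme is not limited to two parties and can be extended to multi-party shared
 * communication, securing data exchanged over the network.
 * The security of the algorithm rests on the difficulty of the discrete logarithm problem (not on the Chinese remainder
 * theorem).
 * <ol>
 * <li>Party A builds a key pair, publishes the public key to party B and keeps the private key. Both parties agree on a
 * data-encryption algorithm. Party B builds its key pair from party A's public key, publishes its public key to party A
 * and keeps its private key.</li>
 * <li>Party A builds the local key from its private key, party B's public key and the agreed algorithm, encrypts the
 * data with it and sends the ciphertext to party B. Party B builds the local key from its private key, party A's public
 * key and the agreed algorithm, and decrypts the data with it.</li>
 * <li>Party B builds the local key from its private key, party A's public key and the agreed algorithm, encrypts the
 * data with it and sends the ciphertext to party A. Party A builds the local key from its private key, party B's public
 * key and the agreed algorithm, and decrypts the data with it.</li>
 * </ol>
 * 
 * @author jianggujin
 * 
 */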
public class JDH {

   /** JCA algorithm name used for key-pair generation, key factories and key agreement. */
   public final static String ALGORITHM = "DH";

   /**
    * Symmetric algorithms that can be keyed with the secret produced by the DH agreement.
    * 
    * @author jianggujin
    *
    */
   public static enum JDHSymmetricalAlgorithm {
      DES, DESede;

      /**
       * Returns the JCA algorithm name of this constant.
       * 
       * @return the constant's name, i.e. {@code "DES"} or {@code "DESede"}
       */
      public String getName() {
         return this.name();
      }
   }

   /**
    * Generates party A's key pair with a key size of 1024 bits.
    * <p>
    * NOTE(review): 1024-bit DH is generally considered too weak for new deployments; callers should prefer
    * {@link #initPartyAKey(int)} with 2048 bits or more. The default is kept unchanged for compatibility.
    * 
    * @return party A's key pair
    */
   public static KeyPair initPartyAKey() {
      return initPartyAKey(1024);
   }

   /**
    * Generates party A's key pair.
    * 
    * @param keySize
    *           key size in bits, passed to {@code JCodecUtils.initKey} together with {@link #ALGORITHM}
    * @return party A's key pair
    */
   public static KeyPair initPartyAKey(int keySize) {
      return JCodecUtils.initKey(ALGORITHM, keySize);
   }

   /**
    * Generates party B's key pair from party A's encoded public key.
    * 
    * @param partyAPublicKey
    *           party A's public key, X.509 encoded
    * @return party B's key pair, generated with the same DH parameters as party A's key
    * @throws JCodecException
    *            if the public key cannot be decoded or the key pair cannot be generated
    */
   public static KeyPair initPartyBKey(byte[] partyAPublicKey) {
      PublicKey pubKey;
      try {
         X509EncodedKeySpec x509KeySpec = new X509EncodedKeySpec(partyAPublicKey);
         KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM);
         pubKey = keyFactory.generatePublic(x509KeySpec);
      } catch (Exception e) {
         throw new JCodecException(e);
      }
      return initPartyBKey(pubKey);
   }

   /**
    * Generates party B's key pair from party A's key pair (only its public key is used).
    * 
    * @param partyAKeyPair
    *           party A's key pair
    * @return party B's key pair, generated with the same DH parameters as party A's public key
    * @throws JCodecException
    *            if the key pair cannot be generated
    */
   public static KeyPair initPartyBKey(KeyPair partyAKeyPair) {
      return initPartyBKey(partyAKeyPair.getPublic());
   }

   /**
    * Generates party B's key pair from party A's public key.
    * 
    * @param partyAPublicKey
    *           party A's public key; must be a {@link DHPublicKey}
    * @return party B's key pair, generated with the DH parameters taken from party A's public key
    * @throws JCodecException
    *            if the key is not a DH public key or the key pair cannot be generated
    */
   public static KeyPair initPartyBKey(PublicKey partyAPublicKey) {
      try {
         // Party B must use the same DH group (p, g) as party A, so reuse A's parameters.
         DHParameterSpec dhParamSpec = ((DHPublicKey) partyAPublicKey).getParams();
         KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(ALGORITHM);
         keyPairGenerator.initialize(dhParamSpec);

         return keyPairGenerator.generateKeyPair();
      } catch (Exception e) {
         throw new JCodecException(e);
      }
   }

   /**
    * Encrypts data with the key agreed from this party's private key and the other party's public key.
    * 
    * @param data
    *           data to encrypt
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @return the encrypted data
    */
   public static byte[] encrypt(byte[] data, byte[] privateKey, byte[] publicKey,
         JDHSymmetricalAlgorithm symmetricalAlgorithm) {
      return encrypt(data, privateKey, publicKey, symmetricalAlgorithm.getName());
   }

   /**
    * Encrypts data with the key agreed from this party's private key and the other party's public key.
    * 
    * @param data
    *           data to encrypt
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm
    * @return the encrypted data
    */
   public static byte[] encrypt(byte[] data, byte[] privateKey, byte[] publicKey, String symmetricalAlgorithm) {
      Cipher cipher = getCipher(privateKey, publicKey, symmetricalAlgorithm, Cipher.ENCRYPT_MODE);
      return JCodecUtils.doFinal(data, cipher);
   }

   /**
    * Encrypts data with the key agreed from this party's private key and the other party's public key.
    * 
    * @param data
    *           data to encrypt
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @return the encrypted data
    */
   public static byte[] encrypt(byte[] data, PrivateKey privateKey, PublicKey publicKey,
         JDHSymmetricalAlgorithm symmetricalAlgorithm) {
      return encrypt(data, privateKey, publicKey, symmetricalAlgorithm.getName());
   }

   /**
    * Encrypts data with the key agreed from this party's private key and the other party's public key.
    * 
    * @param data
    *           data to encrypt
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm
    * @return the encrypted data
    */
   public static byte[] encrypt(byte[] data, PrivateKey privateKey, PublicKey publicKey, String symmetricalAlgorithm) {
      Cipher cipher = getCipher(privateKey, publicKey, symmetricalAlgorithm, Cipher.ENCRYPT_MODE);
      return JCodecUtils.doFinal(data, cipher);
   }

   /**
    * Wraps {@code out} in a {@code JCipherOutputStream} backed by an encrypt-mode cipher keyed with the agreed key.
    * 
    * @param out
    *           stream that receives the ciphertext
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @return the encrypting stream
    */
   public static OutputStream wrap(OutputStream out, byte[] privateKey, byte[] publicKey,
         JDHSymmetricalAlgorithm symmetricalAlgorithm) {
      return wrap(out, privateKey, publicKey, symmetricalAlgorithm.getName());
   }

   /**
    * Wraps {@code out} in a {@code JCipherOutputStream} backed by an encrypt-mode cipher keyed with the agreed key.
    * 
    * @param out
    *           stream that receives the ciphertext
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm
    * @return the encrypting stream
    */
   public static OutputStream wrap(OutputStream out, byte[] privateKey, byte[] publicKey, String symmetricalAlgorithm) {
      Cipher cipher = getCipher(privateKey, publicKey, symmetricalAlgorithm, Cipher.ENCRYPT_MODE);
      return new JCipherOutputStream(cipher, out);
   }

   /**
    * Wraps {@code out} in a {@code JCipherOutputStream} backed by an encrypt-mode cipher keyed with the agreed key.
    * 
    * @param out
    *           stream that receives the ciphertext
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @return the encrypting stream
    */
   public static OutputStream wrap(OutputStream out, PrivateKey privateKey, PublicKey publicKey,
         JDHSymmetricalAlgorithm symmetricalAlgorithm) {
      return wrap(out, privateKey, publicKey, symmetricalAlgorithm.getName());
   }

   /**
    * Wraps {@code out} in a {@code JCipherOutputStream} backed by an encrypt-mode cipher keyed with the agreed key.
    * 
    * @param out
    *           stream that receives the ciphertext
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm
    * @return the encrypting stream
    */
   public static OutputStream wrap(OutputStream out, PrivateKey privateKey, PublicKey publicKey,
         String symmetricalAlgorithm) {
      Cipher cipher = getCipher(privateKey, publicKey, symmetricalAlgorithm, Cipher.ENCRYPT_MODE);
      return new JCipherOutputStream(cipher, out);
   }

   /**
    * Decrypts data with the key agreed from this party's private key and the other party's public key.
    * 
    * @param data
    *           data to decrypt
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @return the decrypted data
    */
   public static byte[] decrypt(byte[] data, byte[] privateKey, byte[] publicKey,
         JDHSymmetricalAlgorithm symmetricalAlgorithm) {
      return decrypt(data, privateKey, publicKey, symmetricalAlgorithm.getName());
   }

   /**
    * Decrypts data with the key agreed from this party's private key and the other party's public key.
    * 
    * @param data
    *           data to decrypt
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm
    * @return the decrypted data
    */
   public static byte[] decrypt(byte[] data, byte[] privateKey, byte[] publicKey, String symmetricalAlgorithm) {
      Cipher cipher = getCipher(privateKey, publicKey, symmetricalAlgorithm, Cipher.DECRYPT_MODE);
      return JCodecUtils.doFinal(data, cipher);
   }

   /**
    * Decrypts data with the key agreed from this party's private key and the other party's public key.
    * 
    * @param data
    *           data to decrypt
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @return the decrypted data
    */
   public static byte[] decrypt(byte[] data, PrivateKey privateKey, PublicKey publicKey,
         JDHSymmetricalAlgorithm symmetricalAlgorithm) {
      return decrypt(data, privateKey, publicKey, symmetricalAlgorithm.getName());
   }

   /**
    * Decrypts data with the key agreed from this party's private key and the other party's public key.
    * 
    * @param data
    *           data to decrypt
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm
    * @return the decrypted data
    */
   public static byte[] decrypt(byte[] data, PrivateKey privateKey, PublicKey publicKey, String symmetricalAlgorithm) {
      Cipher cipher = getCipher(privateKey, publicKey, symmetricalAlgorithm, Cipher.DECRYPT_MODE);
      return JCodecUtils.doFinal(data, cipher);
   }

   /**
    * Wraps {@code in} in a {@code JCipherInputStream} backed by a decrypt-mode cipher keyed with the agreed key.
    * 
    * @param in
    *           stream that supplies the ciphertext
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @return the decrypting stream
    */
   public static InputStream wrap(InputStream in, byte[] privateKey, byte[] publicKey,
         JDHSymmetricalAlgorithm symmetricalAlgorithm) {
      return wrap(in, privateKey, publicKey, symmetricalAlgorithm.getName());
   }

   /**
    * Wraps {@code in} in a {@code JCipherInputStream} backed by a decrypt-mode cipher keyed with the agreed key.
    * 
    * @param in
    *           stream that supplies the ciphertext
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm
    * @return the decrypting stream
    */
   public static InputStream wrap(InputStream in, byte[] privateKey, byte[] publicKey, String symmetricalAlgorithm) {
      Cipher cipher = getCipher(privateKey, publicKey, symmetricalAlgorithm, Cipher.DECRYPT_MODE);
      return new JCipherInputStream(cipher, in);
   }

   /**
    * Wraps {@code in} in a {@code JCipherInputStream} backed by a decrypt-mode cipher keyed with the agreed key.
    * 
    * @param in
    *           stream that supplies the ciphertext
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @return the decrypting stream
    */
   public static InputStream wrap(InputStream in, PrivateKey privateKey, PublicKey publicKey,
         JDHSymmetricalAlgorithm symmetricalAlgorithm) {
      return wrap(in, privateKey, publicKey, symmetricalAlgorithm.getName());
   }

   /**
    * Wraps {@code in} in a {@code JCipherInputStream} backed by a decrypt-mode cipher keyed with the agreed key.
    * 
    * @param in
    *           stream that supplies the ciphertext
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm
    * @return the decrypting stream
    */
   public static InputStream wrap(InputStream in, PrivateKey privateKey, PublicKey publicKey,
         String symmetricalAlgorithm) {
      Cipher cipher = getCipher(privateKey, publicKey, symmetricalAlgorithm, Cipher.DECRYPT_MODE);
      return new JCipherInputStream(cipher, in);
   }

   /**
    * Creates a cipher initialized with the key agreed from encoded keys.
    * 
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @param opmode
    *           cipher operation mode, validated by {@code JCodecUtils.checkOpMode}
    * @return the initialized cipher
    */
   public static Cipher getCipher(byte[] privateKey, byte[] publicKey, JDHSymmetricalAlgorithm symmetricalAlgorithm,
         int opmode) {
      return getCipher(privateKey, publicKey, symmetricalAlgorithm.getName(), opmode);
   }

   /**
    * Creates a cipher initialized with the key agreed from encoded keys.
    * 
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm
    * @param opmode
    *           cipher operation mode, validated by {@code JCodecUtils.checkOpMode}
    * @return the initialized cipher
    */
   public static Cipher getCipher(byte[] privateKey, byte[] publicKey, String symmetricalAlgorithm, int opmode) {
      SecretKey secretKey = getSecretKey(privateKey, publicKey, symmetricalAlgorithm);
      return getCipher(secretKey, opmode);
   }

   /**
    * Creates a cipher initialized with the key agreed from the given keys.
    * 
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @param opmode
    *           cipher operation mode, validated by {@code JCodecUtils.checkOpMode}
    * @return the initialized cipher
    */
   public static Cipher getCipher(PrivateKey privateKey, PublicKey publicKey,
         JDHSymmetricalAlgorithm symmetricalAlgorithm, int opmode) {
      return getCipher(privateKey, publicKey, symmetricalAlgorithm.getName(), opmode);
   }

   /**
    * Creates a cipher initialized with the key agreed from the given keys.
    * 
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm
    * @param opmode
    *           cipher operation mode, validated by {@code JCodecUtils.checkOpMode}
    * @return the initialized cipher
    */
   public static Cipher getCipher(PrivateKey privateKey, PublicKey publicKey, String symmetricalAlgorithm, int opmode) {
      SecretKey secretKey = getSecretKey(privateKey, publicKey, symmetricalAlgorithm);
      return getCipher(secretKey, opmode);
   }

   /**
    * Creates a cipher for the key's algorithm and initializes it with the key.
    * <p>
    * The transformation is only the key's algorithm name, so the provider's default mode and padding apply.
    * NOTE(review): with the SunJCE provider that is ECB/PKCS5Padding, which leaks patterns in the plaintext. It is left
    * unchanged because switching modes would break interoperability with existing ciphertexts.
    * 
    * @param secretKey
    *           the symmetric key
    * @param opmode
    *           cipher operation mode, validated by {@code JCodecUtils.checkOpMode}
    * @return the initialized cipher
    * @throws JCodecException
    *            if the cipher cannot be created or initialized
    */
   public static Cipher getCipher(SecretKey secretKey, int opmode) {
      JCodecUtils.checkOpMode(opmode);
      try {
         Cipher cipher = Cipher.getInstance(secretKey.getAlgorithm());
         cipher.init(opmode, secretKey);
         return cipher;
      } catch (Exception e) {
         throw new JCodecException(e);
      }
   }

   /**
    * Derives the shared symmetric key from encoded keys.
    * 
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @return the agreed secret key
    */
   public static SecretKey getSecretKey(byte[] privateKey, byte[] publicKey,
         JDHSymmetricalAlgorithm symmetricalAlgorithm) {
      return getSecretKey(privateKey, publicKey, symmetricalAlgorithm.getName());
   }

   /**
    * Derives the shared symmetric key from encoded keys.
    * 
    * @param privateKey
    *           this party's private key, PKCS#8 encoded
    * @param publicKey
    *           the other party's public key, X.509 encoded
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm
    * @return the agreed secret key
    * @throws JCodecException
    *            if a key cannot be decoded or the key agreement fails
    */
   public static SecretKey getSecretKey(byte[] privateKey, byte[] publicKey, String symmetricalAlgorithm) {
      PublicKey pubKey;
      PrivateKey priKey;
      try {
         KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM);
         X509EncodedKeySpec x509KeySpec = new X509EncodedKeySpec(publicKey);
         pubKey = keyFactory.generatePublic(x509KeySpec);

         PKCS8EncodedKeySpec pkcs8KeySpec = new PKCS8EncodedKeySpec(privateKey);
         priKey = keyFactory.generatePrivate(pkcs8KeySpec);
      } catch (Exception e) {
         throw new JCodecException(e);
      }
      // Called outside the try block so a JCodecException from the agreement step is not wrapped a second time.
      return getSecretKey(priKey, pubKey, symmetricalAlgorithm);
   }

   /**
    * Derives the shared symmetric key from the given keys.
    * 
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           symmetric algorithm
    * @return the agreed secret key
    */
   public static SecretKey getSecretKey(PrivateKey privateKey, PublicKey publicKey,
         JDHSymmetricalAlgorithm symmetricalAlgorithm) {
      return getSecretKey(privateKey, publicKey, symmetricalAlgorithm.getName());
   }

   /**
    * Runs the DH key agreement and derives a symmetric key for {@code symmetricalAlgorithm}.
    * 
    * @param privateKey
    *           this party's private key
    * @param publicKey
    *           the other party's public key
    * @param symmetricalAlgorithm
    *           JCA name of the symmetric algorithm ({@code DES}, {@code DESede}/{@code TripleDES}, {@code AES} and
    *           {@code Blowfish} are derived locally; other names are passed to the provider)
    * @return the agreed secret key
    * @throws JCodecException
    *            if the key agreement or key derivation fails
    */
   public static SecretKey getSecretKey(PrivateKey privateKey, PublicKey publicKey, String symmetricalAlgorithm) {
      try {
         KeyAgreement keyAgree = KeyAgreement.getInstance(ALGORITHM);
         keyAgree.init(privateKey);
         keyAgree.doPhase(publicKey, true);

         return deriveSecretKey(keyAgree, symmetricalAlgorithm);
      } catch (Exception e) {
         throw new JCodecException(e);
      }
   }

   /**
    * Turns a completed key agreement into a symmetric key for {@code algorithm}.
    * <p>
    * Since JDK 8u161 the SunJCE {@code KeyAgreement.generateSecret(String)} rejects every algorithm except
    * {@code "TlsPremasterSecret"} unless the {@code jdk.crypto.KeyAgreement.legacyKDF} system property is set. On
    * those JDKs it can no longer produce a DES/DESede key. For the algorithms handled here, the raw shared secret is
    * instead truncated the way the legacy provider implementation did, so derived keys stay compatible with keys
    * produced on older JDKs:
    * <ul>
    * <li>DES: first 8 bytes;</li>
    * <li>DESede/TripleDES: first 24 bytes;</li>
    * <li>AES: the largest of 32/24/16 bytes available;</li>
    * <li>Blowfish: at most 56 bytes.</li>
    * </ul>
    * DES parity bits are not adjusted; DES ignores them, so cipher output is unaffected. Any other name (including
    * {@code null}) is still passed to {@link KeyAgreement#generateSecret(String)} unchanged.
    *
    * @param keyAgree
    *           a key agreement on which the final phase has already been executed
    * @param algorithm
    *           JCA name of the symmetric algorithm
    * @return the derived secret key
    * @throws GeneralSecurityException
    *            if the provider cannot produce a key for {@code algorithm}
    */
   private static SecretKey deriveSecretKey(KeyAgreement keyAgree, String algorithm) throws GeneralSecurityException {
      if ("DES".equalsIgnoreCase(algorithm)) {
         return new SecretKeySpec(keyAgree.generateSecret(), 0, 8, "DES");
      }
      if ("DESede".equalsIgnoreCase(algorithm) || "TripleDES".equalsIgnoreCase(algorithm)) {
         return new SecretKeySpec(keyAgree.generateSecret(), 0, 24, "DESede");
      }
      if ("AES".equalsIgnoreCase(algorithm)) {
         byte[] secret = keyAgree.generateSecret();
         int keyLength = secret.length >= 32 ? 32 : (secret.length >= 24 ? 24 : 16);
         return new SecretKeySpec(secret, 0, keyLength, "AES");
      }
      if ("Blowfish".equalsIgnoreCase(algorithm)) {
         byte[] secret = keyAgree.generateSecret();
         return new SecretKeySpec(secret, 0, Math.min(secret.length, 56), "Blowfish");
      }
      return keyAgree.generateSecret(algorithm);
   }

}
